# Start from a clean workspace: removes every object that ls() lists in the
# current environment. Attached packages and options are not affected.
rm(list = ls())

## Load packages

library(tidyverse)
library(rdrobust)
library(mediation)
library(pbapply)

## Covariates used for bandwidth selection and in the regression models

covs <- c(
  "turnout_party_2009_btw",
  "soz_vers_beschaeftigte_share",
  "pop_density_km2",
  "migration_out_share"
)

## Municipal finance data; the state identifier is stored as a factor

census <- mutate(
  read_rds("data/muni_finance.rds"),
  state_id = factor(state_id)
)

## Replace the factor levels of state_id with short state names
## (assigned positionally, i.e. in the order of the existing levels)

levels(census$state_id) <- c(
  "S-H", "HH", "NDS", "Bremen",
  "NRW", "Hessen", "RP", "BW",
  "Bayern", "Saarland", "Berlin", "Brandenburg",
  "M-V", "Sachsen", "Sachsen-Anhalt", "Thueringen"
)

## Divide all spending variables by population (2012) to get per-capita
## values. The lambda has to be the second argument of across(); passing it
## to matches() would hand it to `ignore.case` instead of applying it to the
## selected columns.

census <- census %>%
  mutate(across(matches("bt_spending_total"), ~ .x / bt_pop_2012))

## Convert to long format.
## Each matched column name is split at its last underscore into a variable
## name and a year; the variable names are then spread back into columns.
## Only rows with a December-2009 population between 0 and 20000 are kept.

census_long <- census %>%
  pivot_longer(
    cols = matches("bt_spending_total"),
    names_to = c("name", "year"),
    names_sep = "_(?!.*_)"
  ) %>%
  pivot_wider(names_from = "name", values_from = "value") %>%
  mutate(year = as.numeric(year)) %>%
  filter(pop_dec_09 >= 0, pop_dec_09 <= 20000)

## State-level census information. Long state names are recoded to the short
## labels so that they can be matched against `state_id` in the main data.

states <- read_rds("data/states_census.rds") %>%
  dplyr::select(state_name, applies_census, census_first_year) %>%
  mutate(
    state_name = dplyr::recode(
      state_name,
      "Schleswig-Holstein" = "S-H",
      "Niedersachsen" = "NDS",
      "Nordrhein-Westfalen" = "NRW",
      "Thüringen" = "Thueringen",
      "Baden-Württemberg" = "BW",
      "Mecklenburg-Vorpommern" = "M-V",
      "Rheinland-Pfalz" = "RP"
    )
  )

## Attach the state information to the main (long) data

census_long <- census_long %>%
  left_join(states, by = c("state_id" = "state_name"))

## Time relative to treatment (first census year of the state).
## Rows without a year are dropped. There is no period 0: every value
## greater than -1 is shifted up by one, so the first census year is coded 1.

census_long <- census_long %>%
  filter(!is.na(year)) %>%
  mutate(
    time_rel = year - census_first_year,
    time_rel = case_when(
      time_rel > -1 ~ time_rel + 1,
      TRUE ~ time_rel
    )
  )

## Function that generates a first-differenced data frame
##
## Keeps the rows of `data` whose `time_rel` equals `period1` or `period2`,
## spreads the chosen outcome into the columns `period_1` and `period_2`
## (one row per combination of the id columns) and computes the difference.
##
## Arguments:
##   period1, period2  two distinct values of `time_rel` (base period and
##                     comparison period)
##   outcome           name (single character string) of the outcome column
##   data              long-format data with the columns ags, pop_dec_09,
##                     applies_census, time_rel and the outcome column;
##                     defaults to the global `census_long`
##
## Returns a data frame with the id columns (pop_dec_09, ags,
## applies_census, state_id), period_1, period_2,
## outcome_diff (= period_2 - period_1) and runvar (= 10000 - pop_dec_09).
## Rows with a missing difference are dropped.

gen_fd <- function(period1, period2, outcome, data = census_long) {

  ## Fail early with a clear message instead of an obscure error further down
  stopifnot(
    is.character(outcome), length(outcome) == 1,
    outcome %in% names(data),
    length(period1) == 1, length(period2) == 1,
    period1 != period2
  )

  ## Copy the requested outcome into a column with a fixed name
  data[["outcome"]] <- data[[outcome]]

  data %>%
    ## state_id is rebuilt from the first two characters of the AGS code
    mutate(state_id = substr(ags, 1, 2)) %>%
    filter(time_rel %in% c(period1, period2)) %>%
    mutate(time_rel = case_when(time_rel == period1 ~ 1,
                                time_rel == period2 ~ 2)) %>%
    pivot_wider(values_from = "outcome",
                names_from = time_rel,
                names_prefix = "period_",
                id_cols = c("pop_dec_09", "ags", "applies_census",
                            "state_id")) %>%
    mutate(outcome_diff = period_2 - period_1,
           runvar = (pop_dec_09 * -1) + 10000) %>%
    filter(!is.na(outcome_diff))
}

## Federal election data: keep rows with a non-missing treatment status
## from years after 2012

bt <- readRDS("data/data_federal.rds") %>%
  filter(!is.na(treated), year > 2012)

## Electoral outcomes for which first differences are computed

outcomes <- c(
  "turnout_party",
  "agg_left_party",
  "agg_center_party",
  "agg_right_party"
)

## First differences of the electoral outcomes (2017 minus 2013).
## For every outcome the election data are spread into one column per year
## (o2013, o2017); the difference is stored under the outcome's own name.
## The per-outcome tables are then joined on `ags` and merged with
## population, census status, covariates and state. Only rows with
## applies_census == 1 are kept.

diff_df <- pblapply(outcomes, function(o) {
  out <- bt %>%
    filter(year > 2012) %>%
    ## all_of() makes explicit that `o` is an external character vector and
    ## not a column of `bt`
    pivot_wider(values_from = all_of(o), names_from = "year",
                id_cols = "ags", names_prefix = "o") %>%
    mutate(diff = o2017 - o2013) %>%
    dplyr::select(ags, diff)

  ## Name the difference column after the outcome
  colnames(out)[2] <- o

  out
}) %>%
  ## Explicit key: `ags` is the only column the tables share
  reduce(function(x, y) left_join(x, y, by = "ags")) %>%
  left_join(
    bt %>%
      ## all_of() errors if a covariate is missing; the models below need
      ## every covariate anyway
      dplyr::select(ags, pop_dec_09, applies_census,
                    all_of(covs), state_id) %>%
      distinct(ags, .keep_all = TRUE),
    by = "ags"
  ) %>%
  mutate(runvar = (pop_dec_09 * -1) + 10000) %>%
  filter(applies_census == 1)

## Treatment (mediator) variables

treatments <- c("bt_spending_total")

## Settings for the current run: `t` is the first treatment variable (note
## that the name is shared with base::t()); `exclude_state` is a flag that is
## printed together with the specification further down in this script.

t <- treatments[1]
exclude_state <- TRUE

## Triangular weight / kernel function (to approximate what RD does):
## the weight is 1 at x = 0 and falls linearly to 0 at the largest |x| in `x`

tweight <- function(x) {
  dist <- abs(x)
  1 - dist / max(dist)
}

## Estimate: first difference of total spending between relative
## periods -3 and 3

df_use <- gen_fd(outcome = "bt_spending_total", period1 = -3, period2 = 3)

## Print the current specification

cat("bt_spending_total", exclude_state, -3, 3)

## Merge election results: add the left-party first difference and the
## covariates, keep rows with a non-missing left-party outcome and code the
## treatment indicator (1 if runvar > -1, else 0)

df_use <- df_use %>%
  left_join(
    dplyr::select(diff_df, ags, agg_left_party, one_of(covs))
  ) %>%
  filter(!is.na(agg_left_party)) %>%
  mutate(treat = as.numeric(runvar > -1))

## Select sample: drop state '08' (B-W) and rows with incomplete covariates

subset_select <- df_use$state_id != '08' & complete.cases(df_use[, covs])

## Bandwidth: rdbwselect() on the selected sample with cutoff 0; the first
## entry of `bws` is used

bw_use <- rdbwselect(
  y = df_use$agg_left_party[subset_select],
  x = df_use$runvar[subset_select],
  covs = as.matrix(df_use[subset_select, covs]),
  c = 0
)$bws[1]

## Restrict the sample to the bandwidth (rounded up, plus one) and generate
## triangular weights for the remaining rows; all other rows get NA

subset_select <- subset_select & abs(df_use$runvar) < ceiling(bw_use) + 1
df_use[subset_select, 'w'] <- tweight(df_use$runvar[subset_select])

## Models
##
## Both models are weighted least squares fits on the selected sample
## (rows where subset_select is TRUE), using the triangular weights `w`.
##   m_med  : mediator model, regresses outcome_diff (change in spending) on
##            treat, runvar, their interaction and the covariates
##   m_main : outcome model, regresses agg_left_party on the same terms plus
##            the mediator outcome_diff

m_med <- lm(outcome_diff ~ treat + runvar + runvar*treat +
              turnout_party_2009_btw + soz_vers_beschaeftigte_share + 
              pop_density_km2 + migration_out_share, 
            data=df_use[subset_select, ], 
            weights = df_use[subset_select, ] %>% pull(w))
m_main <- lm(agg_left_party ~ treat + runvar + runvar*treat + outcome_diff + turnout_party_2009_btw + 
               soz_vers_beschaeftigte_share + pop_density_km2 + migration_out_share, 
             data=df_use[subset_select, ], 
             weights = df_use[subset_select, ] %>% pull(w))

## Mediation
##
## The seed is fixed so that the simulation-based results are reproducible.

set.seed(1)

## mediate() combines the mediator and the outcome model; `treat` is the
## treatment, `outcome_diff` the mediator, 50 simulations are drawn and the
## covariate columns of the selected sample are passed as `covariates`.

m <- mediate(model.m = m_med,
             model.y = m_main,
             sims=50, 
             treat="treat", 
             mediator="outcome_diff",
             covariates = df_use[subset_select, covs])

## Collect the mediation results
##
## effect_row() builds a single-row data frame from an estimate, its
## confidence interval (lower, upper), its p-value and a label for the
## quantity; the treatment name `t` is stored as `outcome`.

effect_row <- function(est, ci, p, label) {
  data.frame(estimate = est[1],
             conf.low = ci[1],
             conf.high = ci[2],
             p.value = p[1],
             outcome = t,
             method = label)
}

out_tot <- effect_row(m$tau.coef, m$tau.ci, m$tau.p, 'Total')
out_acme <- effect_row(m$d0, m$d0.ci, m$d0.p, 'ACME')
out_ade <- effect_row(m$z0, m$z0.ci, m$z0.p, 'ADE')
out_prop <- effect_row(m$n0, m$n0.ci, m$n0.p, 'Prop. mediated')

## Combine (row order: Total, Prop. mediated, ACME, ADE) and add the number
## of observations and the bandwidth

out2 <- rbind(out_tot, out_prop, out_acme, out_ade)
out2 <- mutate(out2, n = m$nobs, bw = bw_use)


## Some renaming: readable label for the outcome

res <- mutate(
  out2,
  outcome = recode(outcome, "bt_spending_total" = "Spending")
)

## Prep for table: round the numbers, format the confidence interval as
## "[low, high]", keep the displayed columns and reorder the rows
## (positions 1, 4, 3, 2 of the combined results)

tab <- res %>%
  mutate(
    bw = round(bw),
    p.value = round(p.value, 3),
    estimate = round(estimate, 3),
    ci = paste0("[", round(conf.low, 3), ", ", round(conf.high, 3), "]")
  ) %>%
  dplyr::select(method, estimate, ci, bw, n) %>%
  slice(c(1, 4, 3, 2))

# Table 4: mediation ----
#
# Renders `tab` as a LaTeX table. knitr and kableExtra are not attached by
# this script, so their functions are called with an explicit namespace.

knitr::kable(tab, "latex",
             longtable = FALSE,
             booktabs = TRUE,
             col.names = c('',
                           'Estimate',
                           'CI',
                           '$h_{\\text{MSE}}$',
                           '$n$'),
             linesep = "",
             caption = 'Mediation results\\label{tab:mediation}',
             # escape = FALSE keeps the LaTeX math in the column headers
             escape = FALSE) %>%
  kableExtra::kable_styling(latex_options = c("repeat_header")) %>%
  # Row 0 is the header row
  kableExtra::row_spec(0, bold = TRUE) %>%
  kableExtra::kable_styling(latex_options = "HOLD_position") %>%
  kableExtra::footnote(general = "Stuff")

